Causal Representation Learning

Introduction

Representation learning aims to learn low-dimensional representations that summarize high-dimensional input data such as images, videos, or text. The learned representations can then be used for various downstream tasks. However, traditional representation learning relies on statistical correlations, which may include spurious relations caused by selection bias in the data. To address this issue, current methods learn representations from multi-domain data, but this approach cannot guarantee performance in a new domain and is restrictive for general applications. Hence, we focus on learning causal representations from single-domain data, which is more practical.

Instead of relying on superficial correlations, causal representation learning employs a structural causal model (SCM) to capture the underlying data generation mechanism, which encodes the intrinsic, stable, and interpretable causal relations in the data. As a result, causal representations are robust to changes in exogenous factors, invariant under distribution shifts, and generalize well to out-of-distribution (OOD) settings. However, current causal representation learning methods are limited in several respects. First, they often use only the direct causes of the label (also known as causal features), ignoring the direct effects (anti-causal features) and other related features that are also predictive and discriminative. Second, some works assume causal sufficiency, ignoring latent confounders that introduce spurious correlations between the representations and the label.
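The difference between a stable causal feature and a spurious one can be illustrated with a toy simulation. This is a minimal sketch under assumed, hypothetical settings (binary variables, a 10% label noise rate, and a spurious feature whose agreement with the label flips between the training and test environments); it is not the SCM proposed in our work.

```python
import random

def sample(n, spurious_agreement, rng):
    """Draw n samples (x_c, x_s, y) from a toy SCM (all values are assumptions).

    x_c -> y is the stable causal mechanism: y copies x_c, flipped 10% of the time.
    x_s agrees with y with an environment-specific probability (the spurious link).
    """
    data = []
    for _ in range(n):
        x_c = rng.randint(0, 1)
        y = x_c if rng.random() < 0.9 else 1 - x_c
        x_s = y if rng.random() < spurious_agreement else 1 - y
        data.append((x_c, x_s, y))
    return data

def accuracy(data, feature_index):
    """Accuracy of the simple rule 'predict y = feature value'."""
    return sum(row[feature_index] == row[2] for row in data) / len(data)

rng = random.Random(0)
# The spurious relation is strong in training and reversed at test time.
train = sample(20000, spurious_agreement=0.95, rng=rng)
test = sample(20000, spurious_agreement=0.05, rng=rng)

causal_train, causal_test = accuracy(train, 0), accuracy(test, 0)
spurious_train, spurious_test = accuracy(train, 1), accuracy(test, 1)
print(f"causal feature:   train={causal_train:.2f} test={causal_test:.2f}")
print(f"spurious feature: train={spurious_train:.2f} test={spurious_test:.2f}")
```

In this toy setting the causal feature predicts the label about equally well in both environments, while the spurious feature looks better in training and fails once the environment shifts.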

To address these limitations, we propose a deep causal learning framework that augments deep learning models with causal representation learning. The framework learns a local causal graph consisting of causal features that are discriminative, invariant, and interpretable for predicting the label.

Recent Work

Causal Markov Blanket Representation Learning for Out-of-distribution Generalization

This research addresses poor out-of-distribution (OOD) generalization in machine learning and computer vision. Current methods seek invariant representations either by exploiting domain expertise or by leveraging data from multiple domains. In this paper, we introduce a novel approach that learns Causal Markov Blanket (CMB) representations to improve prediction performance under distribution shifts. CMB representations comprise the direct causes and effects of the target variable. Theoretical analysis shows that they carry the maximum information about the target and therefore yield the minimum Bayes error in prediction. Specifically, we first introduce a novel structural causal model (SCM) with latent representations, designed to capture the causal mechanisms underlying the data generation process. We then propose a CMB representation learning framework that learns representations conforming to the proposed SCM. Compared with state-of-the-art domain generalization methods, our approach is robust and adaptable under distribution shifts. Please refer to the [PDF] for details.
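To make the notion of a Markov blanket concrete, the sketch below reads it off a known directed acyclic graph. It is only an illustration on a hypothetical toy graph, not the CMB learning framework itself, which works with latent representations. The function name `markov_blanket` and the node names are assumptions. Note that the standard graph-theoretic Markov blanket also contains the co-parents of the target's direct effects, so the sketch reports them separately from the direct causes and effects described above.

```python
def markov_blanket(parents, target):
    """Return the Markov blanket of `target` in a DAG, split by role.

    `parents` maps every node to the set of its direct causes.
    """
    direct_causes = set(parents[target])
    direct_effects = {node for node, pa in parents.items() if target in pa}
    # Co-parents ("spouses"): other direct causes of the target's effects.
    co_parents = set()
    for child in direct_effects:
        co_parents |= set(parents[child])
    co_parents -= {target}
    return {
        "causes": direct_causes,
        "effects": direct_effects,
        "co_parents": co_parents - direct_causes - direct_effects,
    }

# Hypothetical toy graph: Z4 -> Z1 -> Y -> Z2 <- Z3.
toy_dag = {
    "Z1": {"Z4"},
    "Y": {"Z1"},
    "Z2": {"Y", "Z3"},
    "Z3": set(),
    "Z4": set(),
}
blanket = markov_blanket(toy_dag, "Y")
print(blanket)
```

Here `Z1` is the direct cause of `Y`, `Z2` is its direct effect, and `Z3` is a co-parent of `Z2`. `Z4` lies outside the blanket because it influences `Y` only through `Z1`.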

Causal Representation Learning and Inference for Generalizable Cross-Domain Predictions

This research addresses poor cross-domain generalization in machine learning and computer vision tasks. Current methods use data from multiple domains and seek to transfer invariant representations to new, unseen domains. This paper proposes to perform causal inference on a transportable, invariant interventional distribution to improve prediction performance under distribution shifts. To do so, we first propose an identifiable structural causal model (SCM) that captures the causal mechanism underlying the data generation process. Based on the proposed SCM, we then introduce a latent representation learning framework that discovers the latent variables and captures the underlying data generation mechanisms. Next, we propose an inference procedure to estimate the invariant, transportable interventional distribution, which accounts for confounding effects between the input and the label. Finally, we empirically demonstrate the robustness of the proposed method under distribution shifts on multiple real-world benchmark datasets. Empirical results show that the proposed method outperforms the majority of domain generalization baselines, achieving state-of-the-art performance.
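The gap between an observational and an interventional distribution can be shown with a small worked example. The sketch below uses the textbook backdoor adjustment, P(y | do(x)) = sum over z of P(y | x, z) P(z), and assumes a single observed, binary confounder `z` with made-up probabilities. It is not the inference procedure of the paper, where the confounding variables are handled through the learned latent representation.

```python
# Hypothetical discrete SCM: z -> x, z -> y, x -> y (all numbers are assumptions).
p_z = {0: 0.5, 1: 0.5}
p_x1_given_z = {0: 0.2, 1: 0.8}
p_y1_given_xz = {(0, 0): 0.1, (1, 0): 0.3, (0, 1): 0.6, (1, 1): 0.8}

def p_x_given_z(x, z):
    """P(x | z) for binary x."""
    return p_x1_given_z[z] if x == 1 else 1.0 - p_x1_given_z[z]

def observational(x):
    """P(y=1 | x): each stratum of z is weighted by P(z | x)."""
    p_x = sum(p_x_given_z(x, z) * p_z[z] for z in p_z)
    return sum(
        p_y1_given_xz[(x, z)] * p_x_given_z(x, z) * p_z[z] / p_x for z in p_z
    )

def interventional(x):
    """P(y=1 | do(x)) by backdoor adjustment: each stratum is weighted by P(z)."""
    return sum(p_y1_given_xz[(x, z)] * p_z[z] for z in p_z)

obs_gap = observational(1) - observational(0)
causal_effect = interventional(1) - interventional(0)
print(f"observational difference: {obs_gap:.2f}")
print(f"interventional difference: {causal_effect:.2f}")
```

With these numbers the observational difference is 0.50, while the interventional difference is 0.20. The extra 0.30 comes entirely from the confounder, which is the kind of spurious association the interventional distribution is meant to remove.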

Publications

  • Naiyu Yin, Hanjing Wang, Tian Gao, Amit Dhurandhar, Qiang Ji. Causal Markov Blanket Representation Learning for Out-of-distribution Generalization. Causal Representation Learning Workshop at NeurIPS, 2023. [PDF]